package com.elnguage.lox;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static com.elnguage.lox.TokenType.*;

public class Parser {
    private final List<Token> tokens;
    private int current = 0;

    Parser(List<Token> tokens)
    {
        this.tokens = tokens;
    }

    private Stmt statement()
    {
        if(match(FOR)) return forStatement();
        if(match(IF)) return ifStatement();
        if(match(PRINT))
        {
            return printStatement();
        }
        if (match(WHILE)) return whileStatement();
        if (match(LEFT_BRACE)) return new Stmt.Block(block());
        return expressionStatement();
    }
    private Stmt forStatement() 
    {
        consume(LEFT_PAREN, "Expect '(' after 'for'.");
        Stmt initializer;
        if (match(SEMICOLON)) {
          initializer = null;
        } else if (match(VAR)) {
          initializer = varDeclaration();
        } else {
          initializer = expressionStatement();
        }
        Expr condition = null;
        if (!check(SEMICOLON)) {
        condition = expression();
        }
        consume(SEMICOLON, "Expect ';' after loop condition.");
        Expr increment = null;
        if (!check(RIGHT_PAREN)) {
          increment = expression();
        }
        consume(RIGHT_PAREN, "Expect ')' after for clauses.");
        Stmt body = statement();
        if (increment != null) 
        {
            body = new Stmt.Block(
                Arrays.asList(
                    body,
                    new Stmt.Expression(increment)));
        }
        if (condition == null) condition = new Expr.Literal(true);
        body = new Stmt.While(condition, body);
        if (initializer != null) {
            body = new Stmt.Block(Arrays.asList(initializer, body));
          }
        return body;
    }

    private Stmt ifStatement() {
        consume(LEFT_PAREN, "Expect '(' after 'if'.");
        Expr condition = expression();
        consume(RIGHT_PAREN, "Expect ')' after if condition."); 
    
        Stmt thenBranch = statement();
        Stmt elseBranch = null;
        if (match(ELSE)) {
          elseBranch = statement();
        }
    
        return new Stmt.If(condition, thenBranch, elseBranch);
      }

    private Stmt printStatement()
    {
        Expr value = expression();
        consume(SEMICOLON, "Expect ';' after value");
        return new Stmt.Print(value);
    }

    private Stmt varDeclaration() {
        Token name = consume(IDENTIFIER, "Expect variable name.");
    
        Expr initializer = null;
        if (match(EQUAL)) {
          initializer = expression();
        }
    
        consume(SEMICOLON, "Expect ';' after variable declaration.");
        return new Stmt.Var(name, initializer);
    }

    private Stmt whileStatement() {
        consume(LEFT_PAREN, "Expect '(' after 'while'.");
        Expr condition = expression();
        consume(RIGHT_PAREN, "Expect ')' after condition.");
        Stmt body = statement();
    
        return new Stmt.While(condition, body);
    }

    private Stmt expressionStatement()
    {
        Expr expr = expression();
        consume(SEMICOLON, "Expect ';' after expression");
        return new Stmt.Expression(expr);
    }

    private Stmt.Function function(String kind) {
        Token name = consume(IDENTIFIER, "Expect " + kind + " name.");
        consume(LEFT_PAREN, "Expect '(' after " + kind + " name.");
        List<Token> parameters = new ArrayList<>();
        if (!check(RIGHT_PAREN)) {
          do {
            if (parameters.size() >= 255) {
              error(peek(), "Can't have more than 255 parameters.");
            }
    
            parameters.add(
                consume(IDENTIFIER, "Expect parameter name."));
          } while (match(COMMA));
        }
        consume(RIGHT_PAREN, "Expect ')' after parameters.");
        consume(LEFT_BRACE, "Expect '{' before " + kind + " body.");
        List<Stmt> body = block();
        return new Stmt.Function(name, parameters, body);
    }

    //解析事件定义的语法节点，包括事件名和事件体
    private StellarisStmt.Event stellaEvent()
    {
      Token name = consume(IDENTIFIER, "Expect a stella_event name.");
      consume(LEFT_BRACE, "Expect '{' before stella_event body");
      List<Stmt> body = block();
      return new StellarisStmt.Event(name, body);
    }

    //解析句柄定义的语法节点，包括句柄名和句柄体
    private StellarisStmt.Handle stellaHandle()
    {
      Token name = consume(IDENTIFIER, "Expect a stella_handle name.");
      consume(LEFT_BRACE, "Expect '{' before stella_handle body");
      List<Stmt> body = block();
      return new StellarisStmt.Handle(name, body);      
    }

    private List<Stmt> block() {
        List<Stmt> statements = new ArrayList<>();
    
        while (!check(RIGHT_BRACE) && !isAtEnd()) {
          statements.add(declaration());
        }
    
        consume(RIGHT_BRACE, "Expect '}' after block.");
        return statements;
      }

    private Expr assignment() {
        Expr expr = or();
    
        if (match(EQUAL)) {
          Token equals = previous();
          Expr value = assignment();
    
          if (expr instanceof Expr.Variable) {
            Token name = ((Expr.Variable)expr).name;
            return new Expr.Assign(name, value);
          }
    
          error(equals, "Invalid assignment target."); 
        }

        if(match(STELLA_OCP))
        {
          Token country = previous();
          Expr target = assignment();
        }

        if(match(STELLA_EVENT_CAT))
        {
          Token cats = previous();
          Expr value = assignment();

          if(expr instanceof Expr.Variable && value instanceof Expr.Variable)
          {
            return new Expr.EventCat(expr, value);
          }
          error(cats, "Invalid assignment target."); 
        }
        
        return expr;
      }
      private Expr or() {
        Expr expr = and();
    
        while (match(OR)) {
          Token operator = previous();
          Expr right = and();
          expr = new Expr.Logical(expr, operator, right);
        }
    
        return expr;
      }

      private Expr and() {
        Expr expr = equality();
    
        while (match(AND)) {
          Token operator = previous();
          Expr right = equality();
          expr = new Expr.Logical(expr, operator, right);
        }
    
        return expr;
      }

    private Expr expression()
    {
        return assignment();
    }

    private Stmt declaration() {
        try {
          if (match(FUN)) return function("function");
          if (match(VAR)) return varDeclaration();
          return statement();
        } catch (ParseError error) {
          synchronize();
          return null;
        }
    }

    private BaseStmt base_declaration()
    {
      try
      {
        if(check(STELLA_EVENT) || check(STELLA_HANDLE))
        {
          StellarisStmt stellarisStmt = stella_declaration();
          return new BaseStmt.StellaStmt(stellarisStmt);
        }
        else
        {
          Stmt statement = declaration();
          return new BaseStmt.BasicStmt(statement);         
        }
      }catch (ParseError error) {
          synchronize();
          return null;
        }
    }
    
    private StellarisStmt stella_declaration()
    {
      try
      {
          if(match(STELLA_EVENT)) return stellaEvent();
          if(match(STELLA_HANDLE)) return stellaHandle();
          return null;
      }
      catch (ParseError error) 
      {
          synchronize();
          return null;
      }
    }

    private Expr equality()
    {
        Expr expr = comparison();

        while (match(BANG_EQUAL, EQUAL_EQUAL))
        {
            Token operator = previous();
            Expr right = comparison();
            expr = new Expr.Binary(expr, operator, right);
        }

        return expr;
    }

    private Expr comparison()
    {
        Expr expr = term();
        
        while (match(GREATER, GREATER_EQUAL, LESS, LESS_EQUAL))
        {
            Token operator = previous();
            Expr right = term();
            expr = new Expr.Binary(expr, operator, right);
        }
        return expr;
    }

    private Expr term()
    {
        Expr expr = factor();
        
        while (match(PLUS, MINUS))
        {
            Token operator = previous();
            Expr right = factor();
            expr = new Expr.Binary(expr, operator, right);
        }
        return expr;
    }

    private Expr factor()
    {
        Expr expr = unary();
        
        while (match(STAR, SLASH))
        {
            Token operator = previous();
            Expr right = unary();
            expr = new Expr.Binary(expr, operator, right);
        }
        return expr;
    }

    private Expr unary()
    {
        if(match(BANG, MINUS))
        {
            Token operator = previous();
            Expr right = unary();
            return new Expr.Unary(operator, right);
        }
        else
        {
            return call();
        }
    }

    private Expr call() {
        Expr expr = primary();
    
        while (true) { 
          if (match(LEFT_PAREN)) {
            expr = finishCall(expr);
          } else {
            break;
          }
        }
    
        return expr;
    }

    private Expr finishCall(Expr callee) {
        List<Expr> arguments = new ArrayList<>();
        if (!check(RIGHT_PAREN)) {
          do {
            if (arguments.size() >= 255) {
                error(peek(), "Can't have more than 255 arguments.");
            }
            arguments.add(expression());
          } while (match(COMMA));
        }
    
        Token paren = consume(RIGHT_PAREN,
                              "Expect ')' after arguments.");
    
        return new Expr.Call(callee, paren, arguments);
    }

    private Expr primary()
    {
        if(match(FALSE))
        {
            return new Expr.Literal(false);
        }
        if(match(TRUE))
        {
            return new Expr.Literal(true);
        }
        if(match(NIL))
        {
            return new Expr.Literal(null);
        }

        if(match(NUMBER, STRING))
        {
            return new Expr.Literal(previous().literal);
        }

        if (match(IDENTIFIER)) {
            return new Expr.Variable(previous());
        }

        if(match(LEFT_PAREN))
        {
            Expr expr = expression();
            consume(RIGHT_PAREN, "Expect ')' after expression");
            return new Expr.Grouping(expr);
        }

        throw error(peek(), "Expect expression.");
    }
    private Token consume(TokenType type, String message)
    {
        if(check(type))
        {
            return advance();
        }
        throw error(peek(), message);
    }

    private boolean match(TokenType... types)
    {
        for(TokenType type : types)
        {
            if(check(type)){
                advance();
                return true;
            }
        }
        return false;
    }

    private boolean check(TokenType type)
    {
        if(isAtEnd())   return false;
        return tokens.get(current).type == type;
    }

    private Token advance()
    {
        if(!isAtEnd())
        {
            current++;
        }
        return previous();
    }

    private Token previous()
    {
        return tokens.get(current - 1);
    }

    private Token peek()
    {
        return tokens.get(current);
    }

    private boolean isAtEnd()
    {
        return peek().type == EOF;
    }

    private ParseError error(Token token, String message)
    {
        Lox.error(token, message);
        return new ParseError();
    }

    private void synchronize()
    {
        advance();
        while(!isAtEnd())
        {
            if(previous().type == SEMICOLON)    return;
            switch (peek().type)
            {
                case CLASS:
                case FUN:
                case VAR:
                case FOR:
                case IF:
                case WHILE:
                case PRINT:
                case RETURN:
                    return;
                default:
            }
            advance();
        }

    }

    private static class ParseError extends RuntimeException{}
    List<BaseStmt> parse()
    {
        List<BaseStmt> statements = new ArrayList<>();
        try{
            while(!isAtEnd())
            {
                statements.add(base_declaration());
            }
            return statements;
        }
        catch(ParseError error)
        {
            System.out.println(error.getMessage());
            return null;
        }
    }
}
